# This script takes the raw predicted labels from each model, computes performance metrics
# (with bootstrapped standard errors), and generates the figures and Section 5.1 numbers.
import pandas as pd
import numpy as np
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score
from sklearn.utils import resample
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
# Keep script output readable by ignoring every UserWarning and FutureWarning.
# NOTE(review): this also hides pandas/seaborn deprecation notices, so a library
# upgrade can break the script without any advance warning.
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", FutureWarning)
pd.options.future.no_silent_downcasting = True  # opt in to pandas' future no-silent-downcasting behaviour
pd.set_option('mode.chained_assignment', None)  # disable SettingWithCopy warnings

# Test-set predictions: gold 'entailment' labels plus one prediction column per model
test = pd.read_csv('./data/polnli_test_results.csv')

#####################
## Utility Functions
#####################
def metrics(df, preds, group_by=None, true_col='entailment'):
    """
    Calculate MCC, accuracy and macro-averaged F1 for each prediction column.

    Args:
        df (pd.DataFrame): The input DataFrame containing true and predicted labels.
        preds (list): List of column names containing model predictions.
        group_by (str, optional): 'dataset' or 'task' to compute the metrics
            separately for every value of that column; None (default) computes
            them over the whole DataFrame.
        true_col (str, optional): Column holding the gold labels.
            Defaults to 'entailment'.

    Returns:
        pd.DataFrame: Columns MCC, Accuracy, F1. Indexed by 'Column' (the
        prediction column name) or, when grouped, by the MultiIndex
        ['Column', group_by.capitalize()].

    Raises:
        ValueError: If `group_by` is not None, 'dataset' or 'task'.
    """
    # Fail loudly instead of silently returning ungrouped metrics for a typo
    # such as 'Task' or 'datasets'.
    if group_by is not None and group_by not in ('dataset', 'task'):
        raise ValueError(
            "group_by must be None, 'dataset' or 'task', got {!r}".format(group_by))
    group_col = group_by.capitalize() if group_by is not None else None

    def get_metrics(y_true, y_pred):
        return {
            'MCC': matthews_corrcoef(y_true, y_pred),
            'Accuracy': accuracy_score(y_true, y_pred),
            'F1': f1_score(y_true, y_pred, average='macro'),
        }

    results = []
    for col in preds:
        if group_col is None:
            row = get_metrics(df[true_col], df[col])
            row['Column'] = col
            results.append(row)
        else:
            for group_name, group in df.groupby(group_by):
                row = get_metrics(group[true_col], group[col])
                row['Column'] = col
                row[group_col] = group_name
                results.append(row)

    index_cols = ['Column'] if group_col is None else ['Column', group_col]
    # An explicit column list keeps the schema (and set_index) valid even when
    # `preds` is empty and `results` has no rows.
    results_df = pd.DataFrame(results, columns=['MCC', 'Accuracy', 'F1', *index_cols])

    if group_col is None:
        return results_df.set_index('Column')
    return results_df.set_index(index_cols)

def bootstrapped_errors(y_true, y_pred, n_bootstrap=1000, random_state=None):
    """
    Calculate bootstrapped standard errors for MCC, Accuracy, and macro F1.

    Each replicate resamples (true, predicted) label pairs with replacement.
    The standard error is the standard deviation (np.std, ddof=0) of the
    metric across replicates.

    Args:
        y_true (array-like): True labels.
        y_pred (array-like): Predicted labels.
        n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 1000.
        random_state (int, np.random.RandomState or None, optional): Seed or
            generator for the resampling. None (default) draws from NumPy's
            global RNG.

    Returns:
        dict: Standard errors under the keys 'MCC_SE', 'Accuracy_SE', 'F1_SE'.
    """
    # Turn an int seed into ONE generator shared by all replicates. Passing the
    # same int to every resample() call would redraw the identical sample each time.
    if isinstance(random_state, (int, np.integer)):
        random_state = np.random.RandomState(random_state)

    mcc_scores, accuracy_scores, f1_scores = [], [], []
    for _ in range(n_bootstrap):
        # Resample with replacement, keeping true/predicted labels paired
        y_true_resampled, y_pred_resampled = resample(
            y_true, y_pred, random_state=random_state)

        mcc_scores.append(matthews_corrcoef(y_true_resampled, y_pred_resampled))
        accuracy_scores.append(accuracy_score(y_true_resampled, y_pred_resampled))
        # Macro average, matching the F1 point estimate computed in metrics().
        # A weighted average here would make F1_SE the error of a different statistic.
        f1_scores.append(f1_score(y_true_resampled, y_pred_resampled, average='macro'))

    return {
        'MCC_SE': np.std(mcc_scores),
        'Accuracy_SE': np.std(accuracy_scores),
        'F1_SE': np.std(f1_scores),
    }

def metrics_with_errors(df, preds, n_bootstrap=1000, group_by=None):
    """
    Calculate metrics and bootstrapped standard errors for predictions, optionally grouped.

    Point estimates come from metrics(); standard errors come from
    bootstrapped_errors(). Both are computed against the 'entailment' column.

    Args:
        df (pd.DataFrame): The input DataFrame containing true and predicted labels.
        preds (list): List of column names containing model predictions.
        n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 1000.
        group_by (str, optional): Column name to group by ('dataset' or 'task'). Defaults to None.

    Returns:
        pd.DataFrame: Indexed by 'Column', or by ['Column', 'Dataset'/'Task']
        when grouped. It contains:
        - MCC, Accuracy, F1;
        - their standard errors MCC_SE, Accuracy_SE, F1_SE;
        - <metric>_Lower / <metric>_Upper = metric - / + ONE standard error.

        These bounds span +/-1 SE and are NOT 95% confidence intervals. Scale
        the *_SE columns by 1.96 for an approximate 95% interval, as the
        figures do.
    """
    # Step 1: point estimates for each model (and group)
    metrics_df = metrics(df, preds, group_by=group_by)

    # Step 2: bootstrapped standard errors, in the same (column, group) order
    grouped = group_by in ['dataset', 'task']
    group_col = group_by.capitalize() if grouped else None

    errors = []
    for col in preds:
        if grouped:
            for group_name, group in df.groupby(group_by):
                row = bootstrapped_errors(group['entailment'], group[col], n_bootstrap=n_bootstrap)
                row['Column'] = col
                row[group_col] = group_name
                errors.append(row)
        else:
            row = bootstrapped_errors(df['entailment'], df[col], n_bootstrap=n_bootstrap)
            row['Column'] = col
            errors.append(row)

    index_cols = ['Column', group_col] if grouped else ['Column']
    # Explicit columns keep set_index valid even when `preds` is empty
    errors_df = pd.DataFrame(errors, columns=['MCC_SE', 'Accuracy_SE', 'F1_SE', *index_cols])
    errors_df = errors_df.set_index(index_cols if grouped else 'Column')

    # Step 3: align estimates and errors on the shared index
    combined_df = metrics_df.merge(errors_df, left_index=True, right_index=True)

    # Step 4: +/-1 SE bounds for each metric
    for name in ('MCC', 'Accuracy', 'F1'):
        se = combined_df[name + '_SE']
        combined_df[name + '_Lower'] = combined_df[name] - se
        combined_df[name + '_Upper'] = combined_df[name] + se

    return combined_df


# Prediction columns in `test` (one per model) to score against 'entailment'
columns = ['base_nli',
           'large_nli',
           'base_debate',
           'large_debate',
           'base_modern',
           'large_modern',
           'llama',
           'llama70b',
           'sonnet']

# Seed NumPy's global RNG. bootstrapped_errors() calls sklearn's resample()
# without an explicit random_state, so it draws from this global generator.
# Seeding makes the standard errors saved below reproducible across runs.
BOOTSTRAP_SEED = 42
np.random.seed(BOOTSTRAP_SEED)

# Score every model on the full test set, per task family, and per source dataset
overall = metrics_with_errors(test, columns, group_by = None)
task = metrics_with_errors(test, columns, group_by = 'task')
dataset = metrics_with_errors(test, columns, group_by = 'dataset')

# Move the index levels back into columns and give all three tables the same
# 'Model' / 'Task' schema so they can be stacked
overall = overall.reset_index().rename(columns={'Column': 'Model'})
overall['Task'] = 'overall'

task = task.reset_index().rename(columns={'Column': 'Model'})

dataset = dataset.reset_index().rename(columns={'Column': 'Model', 'Dataset': 'Task'})
dataset['Task'] = dataset['Task'].str.replace('mlburnham/', '')  # drop the 'mlburnham/' prefix from dataset names

# Combine results into a single matrix; the figure code below reloads this file
combined = pd.concat([overall, task, dataset])
combined.to_csv('./data/results_matrix.csv', index = False)

###########
## Figures
###########
df = pd.read_csv('./data/results_matrix.csv')
# Map raw model identifiers and task-family names to display labels for plotting
df.replace({'base_nli': 'NLI Base',
            'large_nli':'NLI Large',
            'base_debate':'DEBATE Base',
            'llama':'Llama 3.1 8B',
            'llama70b':'Llama 3.3 70B',
            'large_debate':'DEBATE Large',
            'sonnet':'Claude 3.5',
            'base_modern': 'DEBATE Base (MB)',
            'large_modern': 'DEBATE Large (MB)',
            'event extraction': 'Events',
            'hatespeech and toxicity':'Hatespeech',
            'stance detection':'Stance',
            'topic classification':'Topic'
           }, inplace = True)

# Exclude the '(MB)' variants (base_modern / large_modern) from the figures.
# regex=False matches the literal text. As a regex, '(MB)' is a capture group
# that matches a bare 'MB' anywhere in the name.
df = df[~df['Model'].str.contains('(MB)', regex=False)]
df.reset_index(drop = True, inplace = True)

############
## Figure 1
############
# Overall performance: one horizontal bar per model, sorted best-first, with
# error bars of 1.96 x the bootstrapped standard error.
metric = 'F1'
task = 'overall'
data = df[df['Task'] == task].copy()

# Sort by the chosen metric, best model first
data = data.sort_values(metric, ascending=False)
# Fresh 0..n-1 index: row i is drawn at y-position i, which the error bars and
# text annotations below rely on
data.reset_index(drop = True, inplace = True)

# Set the style and font scale
sns.set_style("whitegrid")
sns.set_context("notebook", font_scale=1.5)

# One colorblind-safe colour for every bar (same choice as Figure 2)
bar_color = sns.color_palette("colorblind")[0]

# Create the bar plot. Each model has a single row, so seaborn's own error
# bars are disabled; the bootstrapped ones are added below.
plt.figure(figsize=(16, 8))
ax = sns.barplot(x=metric,
                 y='Model',
                 data=data,
                 color=bar_color,
                 errorbar=None,
                 orient='h')

# Add error bars (1.96 x bootstrapped SE)
ax.errorbar(x=data[metric], y=data.index, xerr=data[metric + '_SE']*1.96,
            fmt='none', ecolor='black', capsize=10)

# Print the metric value, as a percentage, just past the end of each bar
for i, v in enumerate(data[metric]):
    ax.text(v+.015, i+.02, f'{v*100:.1f}%', va='center', fontsize=20)

# Write each model name inside its bar (y tick labels are hidden below)
for i, label in enumerate(data['Model']):
    ax.text(0.01, i, label, va='center', ha='left', color='white', fontweight='bold', fontsize=24)

# Customize the plot
plt.xlabel(metric, fontweight='bold', fontsize=24)
plt.ylabel('')  # Remove y-axis label
ax.set_yticklabels([])
# Format whatever x ticks matplotlib picks as percentages. Hard-coding six
# label strings would mislabel the axis if the tick count differs.
ax.xaxis.set_major_formatter('{x:.0%}')
# Increase font size of tick marks
plt.tick_params(axis='both', which='major', labelsize=16)

# Remove top and right spines
sns.despine(top=True, right=True)

# Keep dashed vertical (x) grid lines; hide horizontal (y) grid lines
plt.grid(axis='x', linestyle='--', alpha=0.7)
plt.grid(axis='y', linestyle='')
# Drop any legend seaborn may have attached
if ax.get_legend() is not None:
    ax.get_legend().remove()

# Save the plot
plt.tight_layout()
plt.savefig('./figures/figure_1.png', dpi = 300)

############
## Figure 2
############
# Per-task-family performance: a 2x2 grid of vertical bar charts, one per task
# family, with bars in a fixed model order and error bars of 1.96 x SE.
tasks = ['Topic', 'Stance', 'Events', 'Hatespeech']
metric = 'F1'
data = df[df['Task'].isin(tasks)].copy()  # Create a copy to avoid SettingWithCopyWarning

modelorder = ["DEBATE Large", "DEBATE Base", "Claude 3.5", "Llama 3.3 70B", "NLI Large", "Llama 3.1 8B", "NLI Base"]

# Ordered categoricals fix both the bar order inside each panel and the row
# order of `data`, so rows line up with the bars seaborn draws
data['Model'] = pd.Categorical(data['Model'], categories=modelorder, ordered=True)
data['Task'] = pd.Categorical(data['Task'], categories=tasks, ordered=True)
data = data.sort_values(['Task', 'Model'])
data.reset_index(drop=True, inplace=True)

# Set up the subplot grid
fig, axs = plt.subplots(2, 2, figsize=(16, 12))
axs = axs.flatten()  # Flatten the 2D array of axes for easier indexing

# Get the first color from the colorblind palette
bar_color = sns.color_palette("colorblind")[0]

# Create a bar plot for each task
for i, task in enumerate(tasks):
    ax = axs[i]
    task_data = data[data['Task'] == task]

    barplot = sns.barplot(
        x="Model",
        y=metric,
        data=task_data,
        color=bar_color,
        errorbar=None,
        ax=ax
    )

    # Bars follow `modelorder`, the same order as task_data's rows, so zipping
    # patches with rows pairs each bar with its own model
    for bar, se, model in zip(barplot.patches, task_data[metric + '_SE'], task_data['Model']):
        bar_x = bar.get_x() + bar.get_width() / 2
        bar_height = bar.get_height()

        # Symmetric error bar: 1.96 x bootstrapped SE
        ax.errorbar(
            x=bar_x,
            y=bar_height,
            yerr=se*1.96,
            fmt='none',
            capsize=10,
            color='black'
        )

        # Model name written vertically inside the bar
        ax.text(bar_x, 0.01, model, ha='center', va='bottom',
                rotation=90, color='white', fontweight='bold', fontsize=20)

    # Customize each subplot
    ax.set_title(task, fontsize=24, fontweight = 'bold')
    ax.set_xlabel('')
    ax.set_ylabel(metric if i % 2 == 0 else '', fontsize = 24, fontweight = 'bold')  # y-label only on left subplots
    ax.tick_params(axis='both', which='major', labelsize=16)
    ax.set_xticks([])  # Remove x-axis ticks
    # Percent labels for whatever y ticks are chosen, independent of tick count
    ax.yaxis.set_major_formatter('{x:.0%}')
    sns.despine(ax=ax, top=True, right=True, bottom=False)  # keep the bottom spine as a baseline
    ax.grid(axis='y', which='both', linestyle='--', alpha=0.7)

    # Fix the y-axis to [0, 1]
    ax.set_ylim(0, 1)

# Adjust layout
plt.tight_layout()
plt.savefig('./figures/figure_2.png', dpi = 300)

############
## Figure 3
############
# Spread of per-dataset performance: one horizontal box per model summarising
# its F1 across the individual datasets listed below
metric = 'F1'
tasks = ['PoliStance_Affect', 'PoliStance_Affect_QT',
       'acled_event_entailment', 'argument_quality_ranking_entailment',
       'bill_summary_entailment', 'dehumanizing_hatespeech_entailment',
       'dem_rep_party_platform_topics', 'ibm_claimstance_entailment',
       'ibm_claimstance_topic_entailment', 'polistance_issue_tweets',
       'scad_event_entailment', 'targeted_hatespeech_entailment',
       'violent_hatespeech_entailment']

modelorder = ["DEBATE Large", "DEBATE Base", "Claude 3.5", "Llama 3.3 70B", "NLI Large", "Llama 3.1 8B", "NLI Base"]

data = df[df['Task'].isin(tasks)].copy()
data['Model'] = pd.Categorical(data['Model'], categories=modelorder, ordered=True)
bar_color = sns.color_palette("colorblind")[0]

# Set the figure size for better visibility
plt.figure(figsize=(14, 8))

# Create the box plot. The colour must go to `color`: seaborn's `fill` is a
# boolean (solid vs. line-art boxes) and does not accept a colour.
sns.boxplot(
    x=metric,
    y='Model',
    data=data,
    color=bar_color  # colorblind-friendly box colour
)
sns.despine(top=True, right=True)
# Label the axes
plt.xlabel(metric, fontweight = 'bold', size = 24)
plt.ylabel('')
plt.tick_params(axis='y', which='major', labelsize=18)
plt.xticks(ticks = [.5, .6, .7, .8, .9, 1], labels = ['50%', '60%', '70%', '80%', '90%', '100%'])
plt.yticks(fontsize = 24, fontweight = 'bold')
plt.grid(axis='x', which='both', linestyle='--', alpha=0.7)
# Improve layout
plt.tight_layout()
plt.savefig('./figures/figure_3.png', dpi = 300)

#######################
## Section 5.1 Numbers
#######################
def _polistance_qt_f1(model):
    """Return `model`'s F1 on PoliStance_Affect_QT, rounded to 3 decimals.

    Uses Series.item() instead of float(Series): float() on a one-element
    Series is deprecated in pandas, and the FutureWarning is suppressed at the
    top of this script. item() also raises ValueError unless exactly one row
    matches, so a missing or duplicated result cannot go unnoticed.
    """
    mask = (df['Model'] == model) & (df['Task'] == 'PoliStance_Affect_QT')
    return round(df.loc[mask, 'F1'].item(), 3)


claude_qt = _polistance_qt_f1('Claude 3.5')
debL_qt = _polistance_qt_f1('DEBATE Large')
deb_qt = _polistance_qt_f1('DEBATE Base')
print("Claude 3.5 Polistance Quote Tweets F1: {}".format(claude_qt))
print("DEBATE Base Polistance Quote Tweets F1: {}".format(deb_qt))
print("DEBATE Large Polistance Quote Tweets F1: {}".format(debL_qt))

